import os
import uuid
import types
from dataclasses import asdict, dataclass
from typing import Any, DefaultDict, Dict, List, Optional, Tuple

import bullet_safety_gym  # noqa
import dsrl
import gymnasium as gym  # noqa
import gym as gym_org
import numpy as np
import pyrallis
import torch
from dsrl.infos import DENSITY_CFG
from dsrl.offline_env import OfflineEnvWrapper, wrap_env  # noqa
from fsrl.utils import WandbLogger
from fsrl.utils import TensorboardLogger
from torch.utils.data import DataLoader
from tqdm.auto import trange  # noqa

from examples.configs.context_encoder_configs import ContextEncoderTrainConfig
from osrl.algorithms import State_AE, Action_AE, inverse_dynamics_model, ActionAETrainer, StateAETrainer
from osrl.algorithms import SafetyAwareEncoder, MultiHeadDecoder, ContextEncoderTrainer, SimpleMlpEncoder
from osrl.common import SequenceDataset, TransitionDataset
from osrl.common.exp_util import auto_name, seed_all, load_config_and_model


@dataclass(frozen=True)
class _TaskSpec:
    """Static description of one offline task used for multi-task training."""

    env_id: str  # id passed to ``gym.make``
    short_name: str  # abbreviation kept for reference; not used by ``train``
    episode_len: int  # episode length listed for the task; not used by ``train``
    state_dim: int  # raw observation size given to ``State_AE``
    action_dim: int  # raw action size given to ``Action_AE``


# One record per task (instead of several hand-aligned parallel lists, which are
# easy to get out of sync). The order must match ``state_encoder_paths`` /
# ``action_encoder_paths`` in the config: entry i of those lists is used for
# entry i here. For a quick single-task debug run, slice this tuple (e.g.
# ``_TASK_SPECS[-1:]``) together with both encoder path lists.
_TASK_SPECS: Tuple[_TaskSpec, ...] = (
    _TaskSpec("OfflinePointButton1Gymnasium-v0", "PointButton1", 1000, 76, 2),
    _TaskSpec("OfflinePointButton2Gymnasium-v0", "PointButton2", 1000, 76, 2),
    _TaskSpec("OfflinePointCircle1Gymnasium-v0", "PointCircle1", 500, 28, 2),
    _TaskSpec("OfflinePointCircle2Gymnasium-v0", "PointCircle2", 500, 28, 2),
    _TaskSpec("OfflinePointGoal1Gymnasium-v0", "PointGoal1", 1000, 60, 2),
    _TaskSpec("OfflinePointGoal2Gymnasium-v0", "PointGoal2", 1000, 60, 2),
    _TaskSpec("OfflinePointPush1Gymnasium-v0", "PointPush1", 1000, 76, 2),
    _TaskSpec("OfflinePointPush2Gymnasium-v0", "PointPush2", 1000, 76, 2),
    _TaskSpec("OfflineHalfCheetahVelocityGymnasium-v0", "HalfCheetahVel-v0", 1000, 17, 6),
    _TaskSpec("OfflineHalfCheetahVelocityGymnasium-v1", "HalfCheetahVel-v1", 1000, 17, 6),
    _TaskSpec("OfflineHopperVelocityGymnasium-v0", "HopperVel-v0", 1000, 11, 3),
    _TaskSpec("OfflineHopperVelocityGymnasium-v1", "HopperVel-v1", 1000, 11, 3),
    _TaskSpec("OfflineCarButton1Gymnasium-v0", "CarButton1", 1000, 88, 2),
    _TaskSpec("OfflineCarButton2Gymnasium-v0", "CarButton2", 1000, 88, 2),
    _TaskSpec("OfflineCarCircle1Gymnasium-v0", "CarCircle1", 500, 40, 2),
    _TaskSpec("OfflineCarCircle2Gymnasium-v0", "CarCircle2", 500, 40, 2),
    _TaskSpec("OfflineCarGoal1Gymnasium-v0", "CarGoal1", 1000, 72, 2),
    _TaskSpec("OfflineCarGoal2Gymnasium-v0", "CarGoal2", 1000, 72, 2),
    _TaskSpec("OfflineCarPush1Gymnasium-v0", "CarPush1", 1000, 88, 2),
    _TaskSpec("OfflineCarPush2Gymnasium-v0", "CarPush2", 1000, 88, 2),
    _TaskSpec("OfflineAntVelocityGymnasium-v0", "AntVel-v0", 1000, 27, 8),
    _TaskSpec("OfflineAntVelocityGymnasium-v1", "AntVel-v1", 1000, 27, 8),
    _TaskSpec("OfflineSwimmerVelocityGymnasium-v0", "SwimmerVel-v0", 1000, 8, 2),
    _TaskSpec("OfflineSwimmerVelocityGymnasium-v1", "SwimmerVel-v1", 1000, 8, 2),
    _TaskSpec("OfflineWalker2dVelocityGymnasium-v0", "Walker2dVel-v0", 1000, 17, 6),
    _TaskSpec("OfflineWalker2dVelocityGymnasium-v1", "Walker2dVel-v1", 1000, 17, 6),
)


def _load_frozen_encoders(
    args: types.SimpleNamespace, specs: Tuple[_TaskSpec, ...]
) -> Tuple[List[Any], List[Optional[Any]]]:
    """Load each task's pretrained state/action autoencoder and switch it to eval mode.

    Entry i of ``args.state_encoder_paths`` / ``args.action_encoder_paths`` is
    used for ``specs[i]``; extra trailing paths are ignored.

    Returns:
        ``(state_encoders, action_encoders)``, both aligned with ``specs``. An
        action-encoder entry is ``None`` when the matching action path is ``None``.

    Raises:
        ValueError: if fewer state or action encoder paths than tasks are configured.
    """
    state_paths = args.state_encoder_paths
    action_paths = args.action_encoder_paths
    # Fail early with a clear message instead of an IndexError halfway through loading.
    if len(state_paths) < len(specs) or len(action_paths) < len(specs):
        raise ValueError(
            f"expected at least {len(specs)} state and action encoder paths, got "
            f"{len(state_paths)} state and {len(action_paths)} action paths"
        )

    state_encoders: List[Any] = []
    action_encoders: List[Optional[Any]] = []
    for spec, state_path, action_path in zip(specs, state_paths, action_paths):
        senc_cfg, senc_model = load_config_and_model(state_path, True)
        state_encoder = State_AE(
            state_dim=spec.state_dim,
            encode_dim=senc_cfg["state_encode_dim"],
            hidden_sizes=senc_cfg["state_encoder_hidden_sizes"],
        )
        state_encoder.load_state_dict(senc_model["model_state"])
        state_encoder.eval()
        state_encoders.append(state_encoder)

        if action_path is None:
            action_encoders.append(None)
            continue
        aenc_cfg, aenc_model = load_config_and_model(action_path, True)
        action_encoder = Action_AE(
            action_dim=spec.action_dim,
            encode_dim=aenc_cfg["action_encode_dim"],
            hidden_sizes=aenc_cfg["action_encoder_hidden_sizes"],
        )
        action_encoder.load_state_dict(aenc_model["model_state"])
        action_encoder.eval()
        action_encoders.append(action_encoder)
    return state_encoders, action_encoders


def _load_task_datasets(
    args: types.SimpleNamespace, specs: Tuple[_TaskSpec, ...]
) -> List[Any]:
    """Create each task's environment and return its pre-processed offline dataset.

    Raises:
        NotImplementedError: if ``args.density != 1.0``. The density settings in
            ``DENSITY_CFG`` are looked up by a single task name, which this
            multi-task script does not have.
    """
    # Checked before any (slow) dataset loading. An explicit raise, unlike
    # ``assert``, is not stripped under ``python -O``.
    if args.density != 1.0:
        raise NotImplementedError(
            "density filtering is not supported for multi-task context-encoder training"
        )

    # Load every raw dataset first, then pre-process. This keeps the same call
    # order as before, so seeded randomness in pre-processing is unchanged.
    envs: List[Any] = []
    raw_datasets: List[Any] = []
    for spec in specs:
        env = gym.make(spec.env_id)
        envs.append(env)
        raw_datasets.append(env.get_dataset())

    return [
        env.pre_process_data(
            data,
            args.outliers_percent,
            args.noise_scale,
            args.inpaint_ranges,
            args.epsilon,
            args.density,
            cbins=None,
            rbins=None,
            max_npb=None,
            min_npb=None,
        )
        for env, data in zip(envs, raw_datasets)
    ]


def _build_context_models(args: types.SimpleNamespace) -> Tuple[Any, Any]:
    """Instantiate the context encoder and multi-head decoder on ``args.device``.

    The encoder input size is ``2 * state_encoding_dim + action_encoding_dim``
    plus 1 extra feature for ``SafetyAwareEncoder`` or 2 for ``SimpleMlpEncoder``.
    Presumably these are the reward (and cost) scalars; confirm in the encoder
    implementations.
    """
    transition_dim = args.state_encoding_dim * 2 + args.action_encoding_dim
    if args.simple_mlp:
        encoder = SimpleMlpEncoder(
            transition_dim + 2,
            args.context_encoder_hidden_sizes,
            args.context_encoding_dim,
        ).to(args.device)
    else:
        encoder = SafetyAwareEncoder(
            transition_dim + 1,
            args.context_encoder_hidden_sizes,
            args.context_encoding_dim,
            simple_gate=args.simple_gate,
        ).to(args.device)
    decoder = MultiHeadDecoder(
        args.context_encoding_dim + args.state_encoding_dim + args.action_encoding_dim,
        args.feature_hidden_sizes,
        args.decoder_hidden_sizes,
        args.state_encoding_dim,
    ).to(args.device)
    return encoder, decoder


def _make_batch_iterators(
    args: types.SimpleNamespace,
    datasets: List[Any],
    state_encoders: List[Any],
    action_encoders: List[Optional[Any]],
) -> List[Any]:
    """Wrap each task's data in a ``TransitionDataset`` and return one batch iterator per task.

    Each iterator yields batches of ``args.context_size`` transitions.

    NOTE(review): callers use ``next`` on these iterators without handling
    ``StopIteration``. This assumes ``TransitionDataset`` yields indefinitely;
    confirm.
    """
    iterators: List[Any] = []
    for data, state_encoder, action_encoder in zip(datasets, state_encoders, action_encoders):
        dataset = TransitionDataset(
            data,
            reward_scale=args.reward_scale,
            cost_scale=args.cost_scale,
            state_encoder=state_encoder,
            action_encoder=action_encoder,
        )
        loader = DataLoader(
            dataset,
            batch_size=args.context_size,
            pin_memory=True,
            num_workers=args.num_workers,
        )
        iterators.append(iter(loader))
    return iterators


def _next_batch(batch_iter: Any, device: str) -> List[torch.Tensor]:
    """Draw one batch and move every field to ``device`` as float32.

    Fields, in order: observations, next_observations, actions, rewards, costs, done.
    """
    return [b.to(device).to(torch.float32) for b in next(batch_iter)]


def _next_batch_with_safe_transition(batch_iter: Any, device: str) -> List[torch.Tensor]:
    """Draw batches until one contains at least one transition with ``cost <= 0``.

    Batches in which every transition has a positive cost are discarded.

    NOTE(review): this loops forever if a task never yields such a batch.
    """
    while True:
        batch = _next_batch(batch_iter, device)
        costs = batch[4]
        if not torch.all(costs > 0):
            return batch


def _visualize_embeddings(
    args: types.SimpleNamespace, trainer: Any, batch_iters: List[Any], step: int
) -> None:
    """Sample ``args.vis_nums`` context batches per task and hand them to the trainer's plotter.

    Each context row is the concatenation ``[obs, action, next_obs, reward]``.
    Does nothing when ``args.logdir`` is ``None``, because there is nowhere to
    save the figure.
    """
    if args.logdir is None:
        return
    # ``args.logdir`` already ends with ``args.name`` (set up in ``train``), so
    # figures land in ``<logdir>/<group>/<name>/<name>/figures``. This layout is
    # kept so existing output locations do not move.
    fig_dir = os.path.join(args.logdir, args.name, "figures")
    os.makedirs(fig_dir, exist_ok=True)
    save_path = os.path.join(fig_dir, f"embedding_vis_step{step + 1}")

    task_ids: List[int] = []
    contexts: List[torch.Tensor] = []
    costs: List[torch.Tensor] = []
    for task_id, batch_iter in enumerate(batch_iters):
        for _ in range(args.vis_nums):
            task_ids.append(task_id)
            obs, next_obs, actions, rewards, batch_costs, _done = _next_batch(
                batch_iter, args.device
            )
            costs.append(batch_costs.reshape(-1, 1))
            contexts.append(
                torch.cat([obs, actions, next_obs, rewards.reshape(-1, 1)], dim=-1)
            )
    trainer.vis_sample_embeddings(contexts, costs, task_ids, len(batch_iters), save_path)


@pyrallis.wrap()
def train(args: ContextEncoderTrainConfig):
    """Jointly train the context encoder and decoder on every task in ``_TASK_SPECS``.

    Pretrained per-task state/action autoencoders (paths taken from the config)
    are loaded in eval mode and passed to the per-task datasets. Each update step:
    - samples ``args.meta_batch_per_task`` batches per task, each containing at
      least one zero-cost transition;
    - runs one trainer step on them.

    Every ``args.eval_every`` steps (and on the last step) the run writes logs and
    a checkpoint. Every ``10 * args.eval_every`` steps it also plots context
    embeddings.
    """
    # Work on a plain namespace so run-specific fields (name/group/logdir) can be
    # rewritten. ``cfg`` keeps the values as parsed from the CLI/config.
    cfg = asdict(args)
    default_cfg = asdict(ContextEncoderTrainConfig())
    args = types.SimpleNamespace(**cfg)

    # setup logger
    if args.name is None:
        args.name = auto_name(default_cfg, cfg, args.prefix, args.suffix)
    if args.group is None:
        args.group = "context_encoder"
    if args.logdir is not None:
        args.logdir = os.path.join(args.logdir, args.group, args.name)
    logger = TensorboardLogger(args.logdir, log_txt=True, name=args.name)
    logger.save_config(cfg, verbose=args.verbose)

    specs = _TASK_SPECS

    seed_all(args.seed)
    if args.device == "cpu":
        torch.set_num_threads(args.threads)

    # Keep this order (encoders -> datasets -> models). Each stage may consume
    # seeded RNG state, and reordering would change seeded runs.
    state_encoders, action_encoders = _load_frozen_encoders(args, specs)
    datasets = _load_task_datasets(args, specs)
    encoder, decoder = _build_context_models(args)

    def checkpoint_fn():
        return {"encoder_state": encoder.state_dict(), "decoder_state": decoder.state_dict()}

    n_params = sum(p.numel() for p in encoder.parameters()) + sum(
        p.numel() for p in decoder.parameters()
    )
    print(f"Total parameters: {n_params}")
    logger.setup_checkpoint_fn(checkpoint_fn)

    trainer = ContextEncoderTrainer(
        encoder,
        decoder,
        logger,
        learning_rate=args.learning_rate,
        betas=args.betas,
        decay_step=args.decay_step,
        decay_rate=args.decay_rate,
        min_learning_rate=args.min_learning_rate,
        state_loss_weight=args.state_loss_weight,
        reward_loss_weight=args.reward_loss_weight,
        cost_loss_weight=args.cost_loss_weight,
        device=args.device,
    )

    batch_iters = _make_batch_iterators(args, datasets, state_encoders, action_encoders)

    for step in trange(args.update_steps, desc="Training"):
        states, actions, next_states, rewards, costs = [], [], [], [], []
        for batch_iter in batch_iters:
            for _ in range(args.meta_batch_per_task):
                obs, next_obs, act, rew, cost, _done = _next_batch_with_safe_transition(
                    batch_iter, args.device
                )
                states.append(obs)
                actions.append(act)
                next_states.append(next_obs)
                rewards.append(rew.reshape(-1, 1))
                costs.append(cost.reshape(-1, 1))
        trainer.train_one_step(states, actions, next_states, rewards, costs)

        is_last_step = step == args.update_steps - 1
        if (step + 1) % args.eval_every == 0 or is_last_step:
            # Checkpoint only on log steps. Serializing the model after every
            # update is costly, and the final step is always a log step.
            logger.save_checkpoint()
            logger.write(step, display=False)
            if (step + 1) % (args.eval_every * 10) == 0 or is_last_step:
                _visualize_embeddings(args, trainer, batch_iters, step)
        else:
            logger.write_without_reset(step)


if __name__ == "__main__":
    train()
